Skip to content

🚨 EP: fix EP router contract for many models + honor FP8 scale format#46818

Open
IlyasMoutawwakil wants to merge 28 commits into
mainfrom
fix-glm-dsa
Open

🚨 EP: fix EP router contract for many models + honor FP8 scale format#46818
IlyasMoutawwakil wants to merge 28 commits into
mainfrom
fix-glm-dsa

Conversation

@IlyasMoutawwakil

Copy link
Copy Markdown
Member

What does this PR do?

Fixes # (issue)

Code Agent Policy

The Transformers repo is currently being overwhelmed by a large number of PRs and issue comments written by
code agents. We are currently bottlenecked by our ability to review and respond to them. As a result,
we ask that new users do not submit pure code agent PRs at this time.
You may use code agents in drafting or to help you diagnose issues. We'd also ask autonomous "OpenClaw"-like agents
not to open any PRs or issues for the moment.

PRs that appear to be fully agent-written will probably be closed without review, and we may block users who do this
repeatedly or maliciously.

This is a rapidly-evolving situation that's causing significant shockwaves in the open-source community. As a result,
this policy is likely to be updated regularly in the near future. For more information, please read CONTRIBUTING.md.

  • I confirm that this is not a pure code agent PR.

Before submitting

  • This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).
  • Did you read the contributor guideline and the
    Pull Request checks?
  • Was this discussed/approved via a Github issue or the forum? Please add a link
    to it if that's the case.
  • Did you make sure to update the documentation with your changes according to the guidelines?
  • Did you write any new necessary tests?

Who can review?

Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.

@IlyasMoutawwakil IlyasMoutawwakil changed the title FP8: Honor the quant config's scale format FP8: Honor the quant config's scale format and fix EP Jun 22, 2026
@HuggingFaceDocBuilderDev

Copy link
Copy Markdown

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

@IlyasMoutawwakil IlyasMoutawwakil marked this pull request as ready for review June 22, 2026 19:52
@IlyasMoutawwakil IlyasMoutawwakil changed the title FP8: Honor the quant config's scale format and fix EP EP+FP8: fix EP router contract for many models and honor FP8 scale format Jun 22, 2026
return Fp8Quantize(self.hf_quantizer)


class Fp8DecodeScale(ConversionOps):

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

any ideas as to why this part was dropped ?

Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

because i added support for ue8m0 scales in finegrained-fp8 v3, this was needed for minimax m3 with the v2, but not anymore, it also wastes memory

Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ue8m0 scales are a bit messy, some store them in the correct torch dtype, some store them in uint8, and some even store them in fp32 for no special reason 😭 i'm trying to tighten the contract and honor the config all the times because supporting all the on-disk variations would be more complicated

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

okay ! Just to be sure, if we remove it now, it would not break existing checkpoints that are in mxpf8 format right ?

@IlyasMoutawwakil IlyasMoutawwakil Jun 23, 2026

Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

no they will work fine, even better because I just noticed that the fp32 scales are even avoiding the optimized mxfp8 path in https://github.com/huggingface/kernels-community/blob/aeb8ef0e09a132a6583c0a4c8b1096292922b54a/finegrained-fp8/torch-ext/finegrained_fp8/utils.py#L64 I also ran minimax m3 integration tests on the b200

Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yep sounds good, just require the version of the kernel for that path to error out properly if kernel version not installed

@IlyasMoutawwakil IlyasMoutawwakil Jun 24, 2026

Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

we do pin the v3 in our lazy loading

intermediate_size = config.moe_intermediate_size * config.n_shared_experts
self.shared_experts = DeepseekOcr2TextMLP(config=config, intermediate_size=intermediate_size)
self.n_routed_experts = config.n_routed_experts
self.num_experts = config.n_routed_experts

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

redundancy in variables ?

Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

yeh i guess we can drop n_routed_experts, removing it

Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

hmm so it seems to cascade into many models

Comment thread tests/test_tensor_parallel_mixin.py Outdated

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

maybe better to add _skip_if_ep_not_supported here instead of within test_ep_*?

Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

tell me if this works for you 0288a11

@vasqu vasqu left a comment

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Only checked the modeling parts re modular and the models themself. It is slightly breaking technically because we move parts around modules so let's add 🚨

Generally aligned with this, just a bit unsure about the minimax m3 change - are we keeping everything as is without dequanting and then only convertin after all conversions? Not sure I can follow there 100%

Comment thread src/transformers/models/deepseek_v2/modular_deepseek_v2.py Outdated
Comment thread src/transformers/models/deepseek_v2/modular_deepseek_v2.py Outdated
Comment thread src/transformers/models/deepseek_v2/modular_deepseek_v2.py Outdated
Comment thread src/transformers/models/deepseek_v2/modular_deepseek_v2.py Outdated
Comment thread src/transformers/models/deepseek_v3/modular_deepseek_v3.py Outdated
Comment thread src/transformers/models/lfm2_moe/modular_lfm2_moe.py Outdated
Comment thread src/transformers/models/longcat_flash/modeling_longcat_flash.py Outdated
Comment thread src/transformers/models/minimax_m3_vl/modular_minimax_m3_vl.py
Comment thread src/transformers/models/mistral4/modular_mistral4.py Outdated
Comment thread src/transformers/models/solar_open/modular_solar_open.py Outdated
@IlyasMoutawwakil IlyasMoutawwakil changed the title EP+FP8: fix EP router contract for many models and honor FP8 scale format 🚨 EP: fix EP router contract for many models + honor FP8 scale format Jun 23, 2026

@vasqu vasqu left a comment

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Just some quick comments, ping me when it's ready for another review. Seems like some comments were resolved but not addressed?

Comment thread src/transformers/models/lfm2_moe/modular_lfm2_moe.py Outdated
Comment thread src/transformers/models/deepseek_v2/modular_deepseek_v2.py Outdated
Comment on lines +1461 to +1466
if not self._ep_plan:
raise ValueError(
f"Expert parallelism was requested (`enable_expert_parallel=True`), but "
f"`{self.__class__.__name__}` does not define an expert-parallel plan. Add a "
f"`base_model_ep_plan` to its config, or disable expert parallelism."
)

Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

loud failure on missing ep plan

index_head_dim: int = 128
index_n_heads: int = 64
mlp_bias: bool = False
num_experts: int = 256

Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this was confusing

Comment on lines +137 to +156
def _process_model_after_weight_loading(self, model, **kwargs):
# dsv4-flash-base stores its (power-of-two) ue8m0 scales in a float32 container under
# `.scale`; those renamed keys keep the on-disk float32 dtype, so cast them to the UE8M0
# dtype the kernels expect (exact, since the values are powers of two). Checkpoints that
# already ship the native float8 E8M0 dtype (e.g. dsv4-flash) are left untouched.
if self.quantization_config.scale_fmt == "ue8m0":
from ..integrations.finegrained_fp8 import _get_ue8m0_dtype

ue8m0 = _get_ue8m0_dtype()
float32_scales = [
name
for name, param in model.named_parameters()
if name.endswith("_scale_inv") and param.dtype == torch.float32
]
for name in float32_scales:
module_name, _, attr = name.rpartition(".")
module = model.get_submodule(module_name)
scale = getattr(module, attr)
setattr(module, attr, torch.nn.Parameter(scale.data.to(ue8m0), requires_grad=False))
return model

Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

either like this or by hooking a quantization op to the scale rename op

@IlyasMoutawwakil IlyasMoutawwakil Jun 24, 2026

Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

okay the second option didn't work

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I kinda prefer with a fp8DecodeScale

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

why does it have to be post proc?

@IlyasMoutawwakil IlyasMoutawwakil Jun 24, 2026

Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Fp8DecodeScale did the opposite, and it targeted mxfp8 where it converted truly ue8m0 to fp32,
this is for for dsv4-flash-base, we need the opposite, ie convert fp32 to ue8m0 to honor the config scale_fmt (because for some reason they stored their ue8ù0 scales in fp32😭), that way we avoid casting, with a new mem allocation, at the entry of each kernel.

Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

why does it have to be post proc?

because the rename catches the dsv4 flash base scales first

@vasqu vasqu left a comment

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ok so I think this looks overall good now, just a few smaller comments. Sometimes we add an attribute mapping so that all variations are kind of covered, not sure if we really need it for all models (would just double check)

The quants re minimax m3 were checked re dequant and quant so I think we are good with the changes but would like to hear @ArthurZucker's opinion on those related changes

Comment thread src/transformers/models/deepseek_ocr2/configuration_deepseek_ocr2.py Outdated
Comment thread src/transformers/models/deepseek_ocr2/modular_deepseek_ocr2.py
Comment thread src/transformers/models/deepseek_v3/configuration_deepseek_v3.py
Comment thread src/transformers/models/deepseek_v3/modular_deepseek_v3.py
Comment thread src/transformers/models/deepseek_v32/configuration_deepseek_v32.py
Comment thread src/transformers/models/ernie4_5_moe/modular_ernie4_5_moe.py Outdated
Comment thread src/transformers/models/ernie4_5_vl_moe/modeling_ernie4_5_vl_moe.py Outdated
@vasqu

vasqu commented Jun 24, 2026

Copy link
Copy Markdown
Collaborator

Let's also update the PR description please so we summarize the changes a bit

  1. FP8 scale changes
  2. EP Plans for all moes
    • Refactor along all models to follow the same format as router/gate -> experts (-> shared experts)
    • Additional miscallenous stuff like erroring out on moes that should have the plan

@ArthurZucker ArthurZucker left a comment

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM, let's make sure kernel V is enforced

return Fp8Quantize(self.hf_quantizer)


class Fp8DecodeScale(ConversionOps):

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yep sounds good, just require the version of the kernel for that path to error out properly if kernel version not installed

if self.layer_types is None:
self.layer_types = ["deepseek_sparse_attention"] * self.num_hidden_layers

if (num_experts := kwargs.get("num_experts")) is not None:

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

mmm is this really something we want? let's not warn no?

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We had 2 values n_routed_experts and num_experts so it's for BC in any case a user explicitly sets this

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Removed the warning, it could indeed trigger unnecessarily

Comment on lines -103 to -107
"layers.*.mlp.experts.gate_up_proj": "grouped_gemm",
"layers.*.mlp.experts.gate_up_proj_scale_inv": "grouped_gemm",
"layers.*.mlp.experts.down_proj": "grouped_gemm",
"layers.*.mlp.experts.down_proj_scale_inv": "grouped_gemm",
"layers.*.mlp.experts": "moe_tp_experts",

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ow shit IDK how this slipped in !

del self.topk_method
self.norm_topk_prob = config.norm_topk_prob

def forward(self, hidden_states):

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

we can probably push standards but its fine

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

(meaning other models do this as well exactly potentially?)

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Not sure what you mean here? It's the same as dsv2 (with a slightly different forward --> no norming at the end of the probs)

Comment on lines +137 to +156
def _process_model_after_weight_loading(self, model, **kwargs):
# dsv4-flash-base stores its (power-of-two) ue8m0 scales in a float32 container under
# `.scale`; those renamed keys keep the on-disk float32 dtype, so cast them to the UE8M0
# dtype the kernels expect (exact, since the values are powers of two). Checkpoints that
# already ship the native float8 E8M0 dtype (e.g. dsv4-flash) are left untouched.
if self.quantization_config.scale_fmt == "ue8m0":
from ..integrations.finegrained_fp8 import _get_ue8m0_dtype

ue8m0 = _get_ue8m0_dtype()
float32_scales = [
name
for name, param in model.named_parameters()
if name.endswith("_scale_inv") and param.dtype == torch.float32
]
for name in float32_scales:
module_name, _, attr = name.rpartition(".")
module = model.get_submodule(module_name)
scale = getattr(module, attr)
setattr(module, attr, torch.nn.Parameter(scale.data.to(ue8m0), requires_grad=False))
return model

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I kinda prefer with a fp8DecodeScale

Comment on lines +137 to +156
def _process_model_after_weight_loading(self, model, **kwargs):
# dsv4-flash-base stores its (power-of-two) ue8m0 scales in a float32 container under
# `.scale`; those renamed keys keep the on-disk float32 dtype, so cast them to the UE8M0
# dtype the kernels expect (exact, since the values are powers of two). Checkpoints that
# already ship the native float8 E8M0 dtype (e.g. dsv4-flash) are left untouched.
if self.quantization_config.scale_fmt == "ue8m0":
from ..integrations.finegrained_fp8 import _get_ue8m0_dtype

ue8m0 = _get_ue8m0_dtype()
float32_scales = [
name
for name, param in model.named_parameters()
if name.endswith("_scale_inv") and param.dtype == torch.float32
]
for name in float32_scales:
module_name, _, attr = name.rpartition(".")
module = model.get_submodule(module_name)
scale = getattr(module, attr)
setattr(module, attr, torch.nn.Parameter(scale.data.to(ue8m0), requires_grad=False))
return model

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

why does it have to be post proc?

parallelism = "Expert" if expert_parallel else "Tensor"
# An EP-capable MoE (@use_experts_implementation) must ship an ep_plan; assert before any
# skip so a plan-less model fails even where the parallel test can't run (GPU, old torch).
if expert_parallel and self._get_tp_model_class()._can_set_experts_implementation():

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

perfect, we want good default EP plan evailable

Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

yeah, we can also make use_experts_impl take care of adding the ep_plan to the config at model init time for example

@github-actions

Copy link
Copy Markdown
Contributor

[For maintainers] Suggested jobs to run (before merge)

run-slow: afmoe, cohere2_moe, deepseek_ocr2, deepseek_v2, deepseek_v3, deepseek_v32, dots1, ernie4_5_moe, ernie4_5_vl_moe, exaone_moe, flex_olmo, glm4_moe, glm4_moe_lite, glm4v_moe, glm_moe_dsa, hunyuan_v1_moe

@github-actions

Copy link
Copy Markdown
Contributor

CI Dashboard: View test results in Grafana

@vasqu vasqu added this pull request to the merge queue Jun 24, 2026
@vasqu vasqu removed this pull request from the merge queue due to a manual request Jun 24, 2026
@vasqu

vasqu commented Jun 24, 2026

Copy link
Copy Markdown
Collaborator

Need to check whether we need to update various conversion mappings; so withholding to merge for now

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

5 participants